{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a07356b3-744a-4319-9fec-cd62f37fa865",
   "metadata": {},
   "source": [
    "# 数据预处理"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ff05a38-2211-4240-a9db-0d79c813ab99",
   "metadata": {},
   "source": [
    "secretflow提供了多种预处理工具，便于用户对数据进行处理。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25ec1569-f9f7-4f27-90b8-a6c7feab28e2",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 前置准备\n",
    "\n",
    "初始化secretflow，创建两个参与方alice和bob。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83e1596d-8ca1-40ae-9681-7254c563ff7e",
   "metadata": {},
   "source": [
 💡 在使用预处理之前">
    "> 💡 Before using the preprocessing tools, you may want to get familiar with SecretFlow's [DataFrame](./DataFrame.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9ad74320-2c3a-4c86-aea4-6688d96d2230",
   "metadata": {},
   "outputs": [],
   "source": [
    "import secretflow as sf\n",
    "\n",
    "sf.init(['alice', 'bob'])\n",
    "alice = sf.PYU('alice')\n",
    "bob = sf.PYU('bob')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94c83c7b-417a-4772-9de1-2efc589cd89f",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 数据准备"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86168ad6-2fe0-4410-b59c-fd65cbe8ea9b",
   "metadata": {},
   "source": [
    "这里我们使用[iris](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_iris.html) 作为示例数据。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9d7d70b8-2d12-40c0-891e-d42cbd567cab",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sepal length (cm)</th>\n",
       "      <th>sepal width (cm)</th>\n",
       "      <th>petal length (cm)</th>\n",
       "      <th>petal width (cm)</th>\n",
       "      <th>target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>5.1</td>\n",
       "      <td>3.5</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4.9</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4.7</td>\n",
       "      <td>3.2</td>\n",
       "      <td>1.3</td>\n",
       "      <td>0.2</td>\n",
       "      <td>setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4.6</td>\n",
       "      <td>3.1</td>\n",
       "      <td>1.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5.0</td>\n",
       "      <td>3.6</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>145</th>\n",
       "      <td>6.7</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.2</td>\n",
       "      <td>2.3</td>\n",
       "      <td>virginica</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>146</th>\n",
       "      <td>6.3</td>\n",
       "      <td>2.5</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1.9</td>\n",
       "      <td>virginica</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>147</th>\n",
       "      <td>6.5</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.2</td>\n",
       "      <td>2.0</td>\n",
       "      <td>virginica</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>148</th>\n",
       "      <td>6.2</td>\n",
       "      <td>3.4</td>\n",
       "      <td>5.4</td>\n",
       "      <td>2.3</td>\n",
       "      <td>virginica</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>149</th>\n",
       "      <td>5.9</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.1</td>\n",
       "      <td>1.8</td>\n",
       "      <td>virginica</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>150 rows × 5 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)  \\\n",
       "0                  5.1               3.5                1.4               0.2   \n",
       "1                  4.9               NaN                1.4               0.2   \n",
       "2                  4.7               3.2                1.3               0.2   \n",
       "3                  4.6               3.1                1.5               0.2   \n",
       "4                  5.0               3.6                1.4               0.2   \n",
       "..                 ...               ...                ...               ...   \n",
       "145                6.7               3.0                5.2               2.3   \n",
       "146                6.3               2.5                5.0               1.9   \n",
       "147                6.5               3.0                5.2               2.0   \n",
       "148                6.2               3.4                5.4               2.3   \n",
       "149                5.9               3.0                5.1               1.8   \n",
       "\n",
       "        target  \n",
       "0       setosa  \n",
       "1       setosa  \n",
       "2       setosa  \n",
       "3       setosa  \n",
       "4       setosa  \n",
       "..         ...  \n",
       "145  virginica  \n",
       "146  virginica  \n",
       "147  virginica  \n",
       "148  virginica  \n",
       "149  virginica  \n",
       "\n",
       "[150 rows x 5 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "from sklearn.datasets import load_iris\n",
    "\n",
    "iris = load_iris(as_frame=True)\n",
    "data = pd.concat([iris.data, iris.target], axis=1)\n",
    "\n",
    "# 为了便于后续的效果展示，这里我们先将部分数据置为None\n",
    "data.iloc[1, 1] = None\n",
    "data.iloc[100, 1] = None\n",
    "\n",
    "# 将target还原为初始名称。\n",
    "data['target'] = data['target'].map({0: 'setosa', 1: 'versicolor', 2: 'virginica' })\n",
    "data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d627e573-65cd-4bb7-ae33-e1419c9e5a6f",
   "metadata": {},
   "source": [
    "创建水平DataFrame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e7fd840d-3ef7-412e-bd72-9e32cd435cc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 对Iris数据集进行水平切分\n",
    "h_alice, h_bob = data.iloc[:70, :], data.iloc[70:, :]\n",
    "\n",
    "# 保存数据集到文件\n",
    "import tempfile\n",
    "\n",
    "_, h_alice_path = tempfile.mkstemp()\n",
    "_, h_bob_path = tempfile.mkstemp()\n",
    "h_alice.to_csv(h_alice_path, index=False)\n",
    "h_bob.to_csv(h_bob_path, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "453384aa-08dc-4b13-aef4-30323d8b64b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from secretflow.data.horizontal import read_csv as h_read_csv\n",
    "from secretflow.security.aggregation import PlainAggregator\n",
    "from secretflow.security.compare import PlainComparator\n",
    "\n",
    "hdf = h_read_csv({alice: h_alice_path, bob: h_bob_path}, \n",
    "                 aggregator=PlainAggregator(alice), \n",
    "                 comparator=PlainComparator(alice))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e109ddee-3998-4c1a-8bb6-797cfe9a31f0",
   "metadata": {},
   "source": [
    "创建垂直DataFrame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "90206eca-f690-412d-9c09-e70e6830f6ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 对Iris数据集进行垂直切分\n",
    "v_alice, v_bob = data.iloc[:, :2], data.iloc[:, 2:]\n",
    "\n",
    "# 保存数据集到文件\n",
    "_, v_alice_path = tempfile.mkstemp()\n",
    "_, v_bob_path = tempfile.mkstemp()\n",
    "v_alice.to_csv(v_alice_path, index=False)\n",
    "v_bob.to_csv(v_bob_path, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9b5ffcc7-8025-453c-94de-bfafe0839461",
   "metadata": {},
   "outputs": [],
   "source": [
    "from secretflow.data.vertical import read_csv as v_read_csv\n",
    "from secretflow.security.aggregation import PlainAggregator\n",
    "from secretflow.security.compare import PlainComparator\n",
    "\n",
    "vdf = v_read_csv({alice: v_alice_path, bob: v_bob_path})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b346a56c-6a5b-4b59-8c36-99344febdb51",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 预处理\n",
    "\n",
    "secretflow提供了缺失值填充、归一化、OneHot编码、标签编码等功能，使用体感上和sklearn的preprocessing类似。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c22e1dc7-508c-487b-8434-fb46f2cd0bcf",
   "metadata": {
    "tags": []
   },
   "source": [
    "### 缺失值填充\n",
    "\n",
    "DataFrame提供了fillna方法，可以对缺失值填充，使用方式和pandas一致。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f53eb206-7ba5-44ab-9b70-713af21ae28c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "148"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 填充前，sepal width (cm)有两个位置缺失。\n",
    "vdf.count()['sepal width (cm)']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b8c94841-73b3-4eb8-913d-af2f29c15ee2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "150"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 对sepal width (cm)进行缺失值填充\n",
    "vdf.fillna(value={'sepal width (cm)': 10}).count()['sepal width (cm)']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "005e1476-c6bb-4a0f-a552-44ffd171e158",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "148"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 填充前，sepal width (cm)有两个位置缺失。\n",
    "hdf.count()['sepal width (cm)']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "df0c69f9-469f-43b0-8a90-ec3589298087",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "150"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 对sepal width (cm)进行缺失值填充\n",
    "hdf.fillna(value={'sepal width (cm)': 10}).count()['sepal width (cm)']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd909a7a-0cde-4cdc-b8a6-bbb09a909773",
   "metadata": {
    "tags": []
   },
   "source": [
    "### 最大最小归一化\n",
    "\n",
    "secretflow提供了MinMaxScaler用作归一化，MinMaxScaler的输入输出都是DataFrame。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "51043015-d6e8-4461-b205-eeb3a966fe83",
   "metadata": {},
   "outputs": [],
   "source": [
    "from secretflow.preprocessing import MinMaxScaler\n",
    "\n",
    "scaler = MinMaxScaler()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "0d720471-55c4-418c-9ea7-bef4aef78666",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Min:  0.0\n",
      "Max:  1.0\n"
     ]
    }
   ],
   "source": [
    "# 对水平DataFrame的sepal length列进行归一化\n",
    "# 等价于：\n",
    "# scaler.fit(hdf['sepal length (cm)'])\n",
    "# scaled_target = scaler.transform(hdf['sepal length (cm)'])\n",
    "scaled_sepal_len_h = scaler.fit_transform(hdf['sepal length (cm)'])\n",
    "\n",
    "print('Min: ', scaled_sepal_len_h.min()[0])\n",
    "print('Max: ', scaled_sepal_len_h.max()[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7667589f-f346-4674-9f9d-4c93ce889b2c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Min:  0.0\n",
      "Max:  1.0\n"
     ]
    }
   ],
   "source": [
    "# 上述操作也可以同样作用于垂直DataFrame。\n",
    "scaled_sepal_len_v = scaler.fit_transform(vdf['sepal length (cm)'])\n",
    "\n",
    "print('Min: ', scaled_sepal_len_v.min()[0])\n",
    "print('Max: ', scaled_sepal_len_v.max()[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d48a464-10df-4a86-82a3-e51fb7d10579",
   "metadata": {
    "tags": []
   },
   "source": [
    "### OneHot编码\n",
    "\n",
    "secretflow提供了OneHotEncoder用于OneHot编码，OneHotEncoder的输入输出都是DataFrame。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "fe2023a4-6e6b-4ed8-892a-78df980e7ded",
   "metadata": {},
   "outputs": [],
   "source": [
    "from secretflow.preprocessing import OneHotEncoder\n",
    "\n",
    "onehot_encoder = OneHotEncoder()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "83b9ad7d-130a-4616-a21d-6ff5e0398a01",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Columns:  Index(['target_setosa', 'target_versicolor', 'target_virginica'], dtype='object')\n",
      "Min: \n",
      " target_setosa        0.0\n",
      "target_versicolor    0.0\n",
      "target_virginica     0.0\n",
      "dtype: float64\n",
      "Max: \n",
      " target_setosa        1.0\n",
      "target_versicolor    1.0\n",
      "target_virginica     1.0\n",
      "dtype: float64\n"
     ]
    }
   ],
   "source": [
    "# 对水平DataFrame的target列进行归一化\n",
    "# 等价于：\n",
    "# onehot_encoder.fit(hdf['target'])\n",
    "# onehot_encoder.transform(hdf['target'])\n",
    "onehot_target_h = onehot_encoder.fit_transform(hdf['target'])\n",
    "\n",
    "print('Columns: ', onehot_target_h.columns)\n",
    "print('Min: \\n', onehot_target_h.min())\n",
    "print('Max: \\n', onehot_target_h.max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "4cdddf7f-9d19-4200-a4c8-3582ef43d0f8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Columns:  Index(['target_setosa', 'target_versicolor', 'target_virginica'], dtype='object')\n",
      "Min: \n",
      " target_setosa        0.0\n",
      "target_versicolor    0.0\n",
      "target_virginica     0.0\n",
      "dtype: float64\n",
      "Max: \n",
      " target_setosa        1.0\n",
      "target_versicolor    1.0\n",
      "target_virginica     1.0\n",
      "dtype: float64\n"
     ]
    }
   ],
   "source": [
    "# 上述操作也可以同样作用于垂直DataFrame。\n",
    "onehot_target_v = onehot_encoder.fit_transform(vdf['target'])\n",
    "\n",
    "print('Columns: ', onehot_target_v.columns)\n",
    "print('Min: \\n', onehot_target_v.min())\n",
    "print('Max: \\n', onehot_target_v.max())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d2434aa-9f1f-40ed-b513-78c179e24635",
   "metadata": {
    "tags": []
   },
   "source": [
    "### 标签编码\n",
    "\n",
    "secretflow提供了LabelEncoder用于标签编码，LabelEncoder的输入输出都是DataFrame。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "3c55619d-c12a-444d-b3a0-ab05d71dacc4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from secretflow.preprocessing import LabelEncoder\n",
    "\n",
    "label_encoder = LabelEncoder()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "49df4d40-c1c5-498e-adf4-bb907b7db4cc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Columns:  Index(['target'], dtype='object')\n",
      "Min: \n",
      " target    0\n",
      "dtype: object\n",
      "Max: \n",
      " target    2\n",
      "dtype: object\n"
     ]
    }
   ],
   "source": [
    "# 对水平DataFrame的target列进行归一化\n",
    "# 等价于：\n",
    "# scaler.fit(hdf['target'])\n",
    "# label_encoder.transform(hdf['target'])\n",
    "encoded_label_h = label_encoder.fit_transform(hdf['target'])\n",
    "\n",
    "print('Columns: ', encoded_label_h.columns)\n",
    "print('Min: \\n', encoded_label_h.min())\n",
    "print('Max: \\n', encoded_label_h.max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "7294d720-15c1-4027-b0f1-bc5d4f9e35ee",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Columns:  Index(['target'], dtype='object')\n",
      "Min: \n",
      " target    0\n",
      "dtype: object\n",
      "Max: \n",
      " target    2\n",
      "dtype: object\n"
     ]
    }
   ],
   "source": [
    "# 上述操作也可以同样作用于垂直DataFrame。\n",
    "encoded_label_v = label_encoder.fit_transform(vdf['target'])\n",
    "\n",
    "print('Columns: ', encoded_label_v.columns)\n",
    "print('Min: \\n', encoded_label_v.min())\n",
    "print('Max: \\n', encoded_label_v.max())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5dd4ddd5",
   "metadata": {},
   "source": [
    "## 收尾"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "831cc32e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 清理临时文件\n",
    "\n",
    "import os\n",
    "\n",
    "try:\n",
    "    os.remove(h_alice_path)\n",
    "    os.remove(h_bob_path)\n",
    "except OSError:\n",
    "    pass\n",
    "\n",
    "try:\n",
    "    os.remove(v_alice_path)\n",
    "    os.remove(v_bob_path)\n",
    "except OSError:\n",
    "    pass"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
